#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/profiler.cuh>

using namespace flashinfer;

// Generator-filled hooks: the double-curly-brace placeholders on the next two
// lines are substituted when this template is rendered, so each macro expands
// to whatever text the generator supplies.
// NOTE(review): the expansion sites are not in this file; presumably a function
// parameter list and a params-setup statement -- confirm at the use sites.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

// PROFILER_PARAMS_SETTER:
//  - with FLASHINFER_ENABLE_PROFILER defined, stores profiler_buffer.data_ptr()
//    (cast to uint64_t*) into params[i].profiler_buffer. The expansion site must
//    therefore have `params`, `i` and `profiler_buffer` in scope, and the
//    expansion already ends with a semicolon.
//  - otherwise it expands to nothing.
#ifdef FLASHINFER_ENABLE_PROFILER
#define PROFILER_PARAMS_SETTER \
  params[i].profiler_buffer = static_cast<uint64_t*>(profiler_buffer.data_ptr());
#else
#define PROFILER_PARAMS_SETTER
#endif

// Generator-supplied declaration(s) are spliced in on the next line.
// NOTE(review): presumably declares the custom attention variant type named
// later in DISPATCH_context -- confirm against the generator.
{{ variant_decl }}

// Attention variant with plain softmax scaling and an optional tanh soft cap
// on the logits, selected at compile time by UseLogitsSoftCap.
//
// Members set by the constructor:
//   sm_scale_log2           - params.sm_scale * math::log2e when the soft cap is
//                             disabled, math::log2e * params.logits_soft_cap when
//                             it is enabled.
//   soft_cap_pre_tanh_scale - params.sm_scale * rcp(params.logits_soft_cap) when
//                             the soft cap is enabled; 0 otherwise (never read
//                             by the logits transform in that configuration).
template <bool UseLogitsSoftCap>
struct StandardAttention : AttentionVariantBase {
  float sm_scale_log2;
  float soft_cap_pre_tanh_scale;
  static constexpr bool use_logits_soft_cap = UseLogitsSoftCap;
  PROFILER_CLOSURE_PARAMS_DECL

  // Builds the variant from a params object exposing `sm_scale` and, when the
  // soft cap is enabled, `logits_soft_cap`. `batch_idx` and `smem_ptr` are part
  // of the constructor signature but are not used by this variant.
  template <typename Params>
  __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
                                        uint8_t* smem_ptr) {
    if constexpr (UseLogitsSoftCap) {
      soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
      sm_scale_log2 = math::log2e * params.logits_soft_cap;
    } else {
      // Give the unused member a definite value so that copying the variant
      // object never propagates an indeterminate float.
      soft_cap_pre_tanh_scale = 0.f;
      sm_scale_log2 = params.sm_scale * math::log2e;
    }
  }

  // Logits hook: with the soft cap enabled, replaces each logit by
  // tanh(logit * soft_cap_pre_tanh_scale); otherwise returns it unchanged.
  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    if constexpr (UseLogitsSoftCap) {
      logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
    }
    return logits;
  })
};

// DISPATCH_context: wraps DISPATCH_MASK_MODE over `mask_mode` (lower case; it
// must be in scope where the macro is expanded) with MASK_MODE as the second
// argument. Inside the dispatched block it declares a type alias, named by the
// AttentionVariant argument, for the generator-supplied variant type, and then
// invokes the trailing variadic argument as a nullary callable.
// The remaining macro parameters (DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK,
// HEAD_DIM_VO, POS_ENCODING_MODE, Params) do not appear in the macro body.
// NOTE(review): presumably the callable is a lambda that refers to those names
// through the file-scope declarations that follow -- confirm at the call sites.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
    using AttentionVariant = {{ variant_name }}; \
    __VA_ARGS__(); \
  })

// File-scope element/index type aliases; each right-hand side is filled in by
// the generator when this template is rendered.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ idtype }};

// Compile-time configuration constants, likewise filled in by the generator.
// NOTE(review): names suggest query/key and value/output head dimensions and a
// positional-encoding selector -- confirm against the kernels that consume them.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};

struct PersistentParams {
  using DTypeQ = DTypeQ;
  using DTypeKV = DTypeKV;
  using DTypeO = DTypeO;
  using IdType = IdType;

  DTypeQ* q;
  DTypeKV* k;
  DTypeKV* v;
  DTypeO* o;
  DTypeO* partial_o;
  float* partial_lse;
  DTypeO* final_o;
  float* final_lse;

  IdType* q_indptr;
  IdType* kv_indptr;
  IdType* partial_indptr;
  IdType* kv_indices;
  IdType* q_len;
  IdType* kv_len;
  IdType* q_start;
  IdType* kv_start;
  IdType* kv_end;
  IdType* kv_head_idx_arr;
  IdType* work_indptr;
  IdType* len_kv_chunk;

  // for state reduction
  IdType* merge_indptr;
  IdType* merge_o_indices;
  IdType* num_packed_qo_len;

  uint32_t num_kv_heads;
  uint_fastdiv gqa_group_size;
  uint_fastdiv page_size;

  uint32_t q_stride_n;
  uint32_t q_stride_h;
  uint32_t k_stride_page;
  uint32_t k_stride_h;
  uint32_t k_stride_n;
  uint32_t v_stride_page;
  uint32_t v_stride_h;
  uint32_t v_stride_n;

  float sm_scale;
  double logits_soft_cap;
  {{ additional_params_decl }}

  PROFILER_PARAMS_DECL
};
